//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the  Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//

package client

import (
	"context"

	"trpc.group/trpc-go/trpc-go/codec"
	"trpc.group/trpc-go/trpc-go/errs"
	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
	"trpc.group/trpc-go/trpc-go/internal/report"
	"trpc.group/trpc-go/trpc-go/transport"
)

// Stream is the interface that performs streaming RPCs.
//
// The methods are listed in the order they are normally used. In this package's
// implementation, Init must run first because the other methods read the options
// it stores. Invoke presumably has to precede Send/Recv since it initializes the
// stream transport (confirm against callers).
type Stream interface {
	// Init initiates all stream related options, applying opt, and returns them.
	Init(ctx context.Context, opt ...Option) (*Options, error)
	// Invoke initiates the lower layer connection to build the stream.
	Invoke(ctx context.Context) error
	// Send sends a stream message.
	Send(ctx context.Context, m interface{}) error
	// Recv receives a stream message as raw bytes; the default implementation
	// leaves deserialization to the caller.
	Recv(ctx context.Context) ([]byte, error)
	// Close closes the stream.
	Close(ctx context.Context) error
}

var (
	// NewStream is the function that returns a Stream. Being a package-level
	// variable, it can be reassigned to plug in another implementation.
	NewStream = func() Stream {
		return new(stream)
	}

	// DefaultStream is the default client Stream. It is built once during
	// package initialization from the initial value of NewStream; reassigning
	// NewStream later does not affect it.
	// NOTE(review): Init stores per-call options on its receiver, so sharing this
	// single instance across concurrent streams looks unsafe — confirm usage.
	DefaultStream = NewStream()
)

// stream is an implementation of Stream.
// It is stateful: Init stores the resolved options on the receiver, and Send,
// Recv, Close and Invoke read them afterwards.
type stream struct {
	// opts holds the options resolved by Init; it is nil until Init succeeds.
	opts *Options
	// client is embedded; Init calls getOptions and updateMsg on the receiver,
	// presumably promoted from this embedded type (confirm).
	client
}

// SendControl is the interface used for sender's flow control.
// The stream's Send consults it only for data frames (non-nil messages).
type SendControl interface {
	// GetWindow is asked for n bytes of send window before a data frame whose
	// serialized body is n bytes long is encoded and sent. A non-nil error
	// aborts that send, and the stream transport is then closed.
	GetWindow(n uint32) error
	// UpdateWindow adjusts the available send window by n bytes. It is not
	// called in this file; presumably it is driven by window updates from the
	// peer (confirm).
	UpdateWindow(n uint32)
}

// RecvControl is the interface used for receiver's flow control.
type RecvControl interface {
	// OnRecv is notified that size bytes were received. It is not called in
	// this file; presumably a non-nil error signals a flow-control violation
	// to the caller (confirm).
	OnRecv(size uint32) error
}

// Send implements Stream.
// It serializes and compresses m, encodes the resulting body into a frame with
// the configured codec, and writes that frame to the stream transport.
//
// A nil m is treated as a non-data frame (Close relies on this) and bypasses
// sender flow control. Whenever Send returns an error, the underlying stream
// transport is closed.
//
// It's safe to call Recv and Send in different goroutines concurrently, but calling
// Send in different goroutines concurrently is not thread-safe.
func (s *stream) Send(ctx context.Context, m interface{}) (err error) {
	// Any failure below, whichever step produced it, tears down the transport.
	defer func() {
		if err != nil {
			s.opts.StreamTransport.Close(ctx)
		}
	}()

	msg := codec.Message(ctx)
	body, err := serializeAndCompress(ctx, msg, m, s.opts)
	if err != nil {
		return err
	}

	// Only data frames (non-nil m) consume send window.
	isDataFrame := m != nil
	if sc := s.opts.SControl; isDataFrame && sc != nil {
		if err = sc.GetWindow(uint32(len(body))); err != nil {
			return err
		}
	}

	frame, encErr := s.opts.Codec.Encode(msg, body)
	if encErr != nil {
		return errs.NewFrameError(errs.RetClientEncodeFail, "client codec Encode: "+encErr.Error())
	}
	return s.opts.StreamTransport.Send(ctx, frame)
}

// Recv implements Stream.
// It reads one frame from the stream transport, decodes it with the configured
// codec and, when a real compress type applies, decompresses the body.
// Deserialization is left to the upper layer, so the returned bytes are the
// raw (decompressed) message body. If, after decoding, msg.ClientRspErr()
// reports an error, that error is returned as-is. Whenever Recv returns an
// error, the underlying stream transport is closed.
//
// It's safe to call Recv and Send in different goroutines concurrently, but
// calling Recv in different goroutines concurrently should be avoided.
// NOTE(review): each call decodes into the message carried by ctx, so
// concurrent Recv calls sharing a ctx would race on it — confirm the contract.
func (s *stream) Recv(ctx context.Context) (buf []byte, err error) {
	// Any failure below, including decode/decompress errors, tears down the transport.
	defer func() {
		if err != nil {
			s.opts.StreamTransport.Close(ctx)
		}
	}()
	rspBuf, err := s.opts.StreamTransport.Recv(ctx)
	if err != nil {
		return nil, err
	}
	msg := codec.Message(ctx)
	rspBodyBuf, err := s.opts.Codec.Decode(msg, rspBuf)
	if err != nil {
		return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decode: "+err.Error())
	}
	// Decoding succeeded, but msg may still carry a response error.
	if err := msg.ClientRspErr(); err != nil {
		return nil, err
	}
	// Nothing to decompress in an empty body.
	if len(rspBodyBuf) == 0 {
		return rspBodyBuf, nil
	}
	// A valid CurrentCompressType on the options overrides the type from msg.
	compressType := msg.CompressType()
	if icodec.IsValidCompressType(s.opts.CurrentCompressType) {
		compressType = s.opts.CurrentCompressType
	}
	// Invalid or no-op compress types mean the body is returned untouched.
	if !icodec.IsValidCompressType(compressType) || compressType == codec.CompressTypeNoop {
		return rspBodyBuf, nil
	}
	rspBodyBuf, err = codec.Decompress(compressType, rspBodyBuf)
	if err != nil {
		return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error())
	}
	return rspBodyBuf, nil
}

// Close implements Stream.
// It sends the close message by calling Send with a nil payload. Send treats a
// nil message as a non-data frame, so no flow-control window is requested, and
// if that send fails Send itself closes the underlying stream transport.
func (s *stream) Close(ctx context.Context) error {
	var closePayload interface{} // nil: marks a non-data (close) frame for Send
	return s.Send(ctx, closePayload)
}

// Init implements Stream.
// In order, it:
//  1. resolves the options for the request message found in ctx,
//  2. updates that message from the options,
//  3. selects a backend node, makes sure msg carries its remote address, and
//     records the node in the options with an invalid (placeholder) cost,
//  4. checks that a codec is configured.
//
// On success, transport.WithMsg(msg) is appended to the call options. The
// options are then stored on the receiver for Send, Recv and Invoke, and also
// returned. Note that the codec check runs only after node selection.
func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) {
	const invalidCost = -1

	// Each backend call works on a fresh msg that the client stub code placed in ctx.
	msg := codec.Message(ctx)

	o, err := s.getOptions(msg, opt...)
	if err != nil {
		return nil, err
	}
	s.updateMsg(msg, o)

	n, err := selectNode(ctx, msg, o)
	if err != nil {
		report.SelectNodeFail.Incr()
		return nil, err
	}
	// The node's own network wins; fall back to the configured one.
	network := findFirstNonEmpty(n.Network, o.Network)
	ensureMsgRemoteAddr(msg, network, n.Address, n.ParseAddr)
	o.Node.set(n, n.Address, invalidCost)

	if o.Codec == nil {
		report.ClientCodecEmpty.Incr()
		return nil, errs.NewFrameError(errs.RetClientEncodeFail, "client: codec empty")
	}

	o.CallOptions = append(o.CallOptions, transport.WithMsg(msg))
	s.opts = o
	return o, nil
}

// findFirstNonEmpty returns the first candidate that is not the empty string,
// scanning left to right. It returns "" when every candidate is empty or when
// none is given. Whitespace-only strings count as non-empty.
func findFirstNonEmpty(ss ...string) string {
	for i := range ss {
		if len(ss[i]) != 0 {
			return ss[i]
		}
	}
	return ""
}

// Invoke implements Stream.
// It initializes the stream transport with the call options prepared by Init,
// which include transport.WithMsg for the request message. s.opts must already
// have been set by Init; otherwise this dereferences a nil pointer.
func (s *stream) Invoke(ctx context.Context) error {
	o := s.opts
	return o.StreamTransport.Init(ctx, o.CallOptions...)
}
